fix(perf): prevent fp32 promotion in model hot paths #20

Merged: inureyes merged 1 commit into main from fix/perf-fp32-promotion-models on May 18, 2026

@inureyes
Member

Summary

Generalises the bf16/f16 dtype-preservation pattern that landed for gpt-oss in #17 to every MoE / hot-path model in the tree. The same FP32 promotion that crushed gpt-oss decode throughput (19 → 113 tok/s) was hiding in 25+ other models: anywhere the MoE expert combine, the activation helper, or the router cast back through an einsum / softmax boundary.

Cherry-picked from the internal repo (mlxcel-internal commit 616c4704); the internal docs/model_tests_m5max.md update from the original commit was intentionally excluded (it referenced an internal-only baseline doc that was never in the public repo).

What changed

Shared helper (src/models/switch_layers.rs, +70):

  • New moe_weighted_sum(expert_out, scores, output_dtype): replaces the old nkh,nk->nh einsum contraction, which promoted to FP32 on M5 for bf16/f16 activations, with mlx-lm's pattern of y * scores[..., None] followed by sum(axis=-2). scores is cast to the expert output dtype, and the final result is restored to the hidden/residual dtype.
  • Unit test covering bf16 round-trip.
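
For reviewers who want the arithmetic spelled out, here is a minimal reference sketch of the combine on plain f32 nested vectors. It is not the code in switch_layers.rs: the name moe_weighted_sum_ref and the Vec-based layout are hypothetical, and dtype handling (the scores cast and the final restore to output_dtype) is left out because std Rust has no bf16/f16 type. It only shows that multiplying each expert row by its score and summing over the expert axis gives the same result as the nkh,nk->nh contraction.

```rust
/// Hypothetical reference version of the MoE combine on plain f32 data.
/// `expert_out` is indexed [n][k][h] (token, expert, hidden),
/// `scores` is indexed [n][k], and the result is indexed [n][h].
/// Dtype casts are intentionally not modelled here.
fn moe_weighted_sum_ref(expert_out: &[Vec<Vec<f32>>], scores: &[Vec<f32>]) -> Vec<Vec<f32>> {
    expert_out
        .iter()
        .zip(scores)
        .map(|(token_experts, token_scores)| {
            // Hidden size is taken from the first expert row (0 if there are no experts).
            let hidden = token_experts.first().map_or(0, |row| row.len());
            let mut acc = vec![0.0f32; hidden];
            for (expert_row, &score) in token_experts.iter().zip(token_scores) {
                // y * scores[..., None]: one score scales a whole hidden row.
                for (a, &y) in acc.iter_mut().zip(expert_row) {
                    // sum(axis=-2): accumulate over the expert axis k.
                    *a += y * score;
                }
            }
            acc
        })
        .collect()
}
```

With one token, two experts with outputs [1, 2] and [3, 4], and scores [0.5, 0.25], the combined row is [1.25, 2.0], which matches what the einsum contraction computes for the same inputs.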

FFI helpers (mlxcel-core/cpp/mlx_cxx_bridge.cpp + ffi_tests.rs +180):

utils (mlxcel-core/src/utils.rs +25): supporting dtype helpers.

Model call-site updates (25 model files): each MoE/router path now routes its combine through moe_weighted_sum and casts router/expert scores back to the input dtype before residual add. Models touched:

deepseek, deepseek_v3, deepseek_v32, ernie4_5_moe, exaone_moe, glm4_moe,
glm4_moe_lite, gpt_oss, hunyuan_moe, kimi_linear, minimax, mistral4,
mixtral, moondream3, olmoe, phimoe, qwen2_moe, qwen3_5, qwen3_moe,
qwen3_next, qwen3_vl_moe, solar_open, step3p5

Verification

  • make verify-fmt — clean
  • make verify-clippy (CI-faithful: --all-targets --features metal,accelerate -- -D warnings) — clean in 2m04s
  • make verify-test skipped here (15-30 min release-mode run); the underlying commit is already validated in mlxcel-internal against the M5 Max benchmark sweep.

Why a single commit

The 25-model sweep is one logical change — same pattern, same root cause, same fix shape per file. Splitting it would multiply review cost without adding signal.

@inureyes added the status:review, type:bug, type:performance, priority:high, area:models, and area:core labels on May 18, 2026
@inureyes merged commit 8dcc84c into main on May 18, 2026 (1 check passed)
@inureyes deleted the fix/perf-fp32-promotion-models branch on May 18, 2026 11:39
@inureyes self-assigned this on May 18, 2026
@inureyes added the status:done label and removed the status:review label on May 18, 2026
@inureyes mentioned this pull request on May 18, 2026 (3 tasks)